{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bf0022a3-6769-4a28-a163-d5408344b23d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scHash\n",
    "import anndata as ad\n",
    "\n",
    "from sklearn.metrics import f1_score, precision_score, recall_score\n",
    "from statistics import median"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "861967af-edc6-47be-8a2f-f66c6c5ebbc1",
   "metadata": {},
   "source": [
    "# Tutorial for scATAC-seq cell annotations\n",
    "Here is an demonstration on scATAC-seq dataset. For ATAC data, we used the mapping method adapted from [GLUE](https://www.nature.com/articles/s41587-022-01284-4) to map peaks to genes. The dataset can also be obtained from the GLUE github page. We used Muto-2021 for demonstration."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67da14ca-0fb5-4b5d-8599-1cd4384044cb",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bf3f0bd2-3658-44d7-8ea4-c0e13e5309ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define data path\n",
    "data_dir = 'Muto-2021-FRAGS2RNA.h5ad'\n",
    "\n",
    "# This data contains both the reference and query source\n",
    "data = ad.read_h5ad(data_dir)\n",
    "\n",
    "# we randomly picked '5028f75a-8c09-4155-a232-ad7dbfa6042e' as query\n",
    "query='5028f75a-8c09-4155-a232-ad7dbfa6042e'\n",
    "train = data[data.obs.batch!=query]\n",
    "test = data[data.obs.batch==query]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b6cb2c4-7d29-4d06-ad8e-4ed36e2829d2",
   "metadata": {},
   "source": [
    "## Training Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ade212-ca91-43c7-8182-9c4cbf248e58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# set up the training datamodule\n",
    "datamodule = scHash.setup_training_data(train_data = train,cell_type_key = 'cell_type', batch_key = 'batch')\n",
    "\n",
    "# set a directory to save the model \n",
    "checkpointPath = '../checkpoint/'\n",
    "\n",
    "# initiliza scHash model and train \n",
    "model = scHash.scHashModel(datamodule)\n",
    "trainer, best_model_path, training_time = scHash.training(model = model, datamodule = datamodule, checkpointPath = checkpointPath, max_epochs = 70)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4b826ab-1ffb-494e-8037-9a02a6407e3d",
   "metadata": {},
   "source": [
    "## Test Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fdecaada-8e97-46b5-94c9-6807c734061f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1 Median: 0.965\n"
     ]
    }
   ],
   "source": [
    "# add the test data\n",
    "datamodule.setup_test_data(test)\n",
    "\n",
    "# test the model\n",
    "pred_labels, hash_codes = scHash.testing(trainer, model, best_model_path)\n",
    "\n",
    "# show the test performance\n",
    "labels_true = test.obs.cell_type\n",
    "f1_median = round(median(f1_score(labels_true,pred_labels,average=None)),3)\n",
    "\n",
    "print(f'F1 Median: {f1_median}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.7 (v3.8.7:6503f05dd5, Dec 21 2020, 12:45:15) \n[Clang 6.0 (clang-600.0.57)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
